In this notebook, we use the first 5 MNIST digits to compare the different approaches we have to approximate the Riemannian distance:
Stochastic Riemannian length of shortest curves found:
Time per 100 geodesic approximations:
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
from copy import deepcopy
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Set up plotly for inline rendering in the notebook
init_notebook_mode(connected=True)
# Shared figure layout reused by most plots below
layout = go.Layout(
    width=700,
    height=500,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    showlegend=False
)
# Hide the "Export to plot.ly" link on every figure
config={'showLink': False}
# Make results completely repeatable: seed both NumPy and TensorFlow
# (tf.set_random_seed is the TF1 graph-level seed)
seed = 0
np.random.seed(seed)
tf.set_random_seed(seed)
# The MNIST digit classes used throughout this notebook
digit_classes = [0,1,2,3,4]
We build the VAE following the implementation details in appendix D of the Latent Space Oddity paper.
from tensorflow.python.keras import Sequential, Model
from tensorflow.python.keras.layers import Dense, Input, Lambda
from src.vae import VAE
from src.rbf import RBFLayer
# Implementation details from Appendix D of the Latent Space Oddity paper.
input_dim = 784
latent_dim = 2
l2_reg = tf.keras.regularizers.l2(1e-5)
# Create the encoder models.
# The first hidden layer is shared between the mean and variance encoders.
enc_input = Input((input_dim,))
enc_shared = Dense(64, activation='tanh', kernel_regularizer=l2_reg)
enc_mean = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='linear', kernel_regularizer=l2_reg)
])
# softplus keeps the predicted latent variance positive
enc_var = Sequential([
    enc_shared,
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(latent_dim, activation='softplus', kernel_regularizer=l2_reg)
])
enc_mean = Model(enc_input, enc_mean(enc_input))
enc_var = Model(enc_input, enc_var(enc_input))
# Create the decoder models
dec_input = Input((latent_dim,))
dec_mean = Sequential([
    Dense(32, activation='tanh', kernel_regularizer=l2_reg),
    Dense(64, activation='tanh', kernel_regularizer=l2_reg),
    Dense(input_dim, activation='sigmoid', kernel_regularizer=l2_reg)
])
dec_mean = Model(dec_input, dec_mean(dec_input))
# Build the RBF network modelling the decoder variance.
# Its centers and bandwidths are set after clustering the latent codes below;
# `a` scales the per-cluster bandwidths (see the bandwidth computation).
num_centers = 64
a = 1.0
rbf = RBFLayer([input_dim], num_centers)
dec_var = Model(dec_input, rbf(dec_input))
vae = VAE(enc_mean, enc_var, dec_mean, dec_var, dec_stddev=1.)
from tensorflow.python.keras.datasets import mnist
# Train the VAE on MNIST digits
(x_train_all, y_train_all), _ = mnist.load_data()
# Filter the digit classes from the MNIST data.
# Vectorized per-class masking replaces the previous Python-level double loop
# (5 x 60000 iterations) while preserving the exact grouped-by-class ordering
# that the hard-coded point indices further below rely on.
x_parts = []
y_parts = []
for digit_class in digit_classes:
    class_mask = y_train_all == digit_class
    x_parts.append(x_train_all[class_mask])
    y_parts.append(y_train_all[class_mask])
# Scale pixel values to [0, 1] and flatten the 28x28 images to 784-vectors
x_train = np.concatenate(x_parts).astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
y_train = np.concatenate(y_parts)
# Shuffle the data with a fixed (seeded) permutation
p = np.random.permutation(len(x_train))
x_train = x_train[p]
y_train = y_train[p]
We first train the VAE without training the generator's variance network. This will be trained separately later.
# Fit the VAE (unsupervised, so only x_train is passed); 10% of the data is
# held out for validation
history = vae.model.fit(x_train,
                        epochs=100,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)
# Plot the train/validation losses (negative ELBO) per epoch
data = [go.Scatter(y=history.history['loss'], name='Train Loss'),
        go.Scatter(y=history.history['val_loss'], name='Validation Loss')]
# Copy the shared layout so the axis titles don't leak into other plots
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
# Display a 2D plot of the classes in the latent space
encoded_sampled, encoded_mean, encoded_var = vae.encoder.predict(x_train)
# Plot up to 500 encoded mean points per digit class
scatter_data = []
colors = ['#825446', '#26A69A', '#F7B554', '#2D4366', '#84439A']
plot_points = []
for i_class, digit_class in enumerate(digit_classes):
    # Vectorized index lookup (np.flatnonzero) instead of building a
    # Python-level boolean list and masking an arange — same indices
    class_indices = np.flatnonzero(y_train == digit_class)[:500]
    x_class = encoded_mean[class_indices]
    plot_points.extend(x_class)
    scatter_data.append(go.Scatter(
        x = x_class[:, 0],
        y = x_class[:, 1],
        mode = 'markers',
        marker = {'color': colors[i_class]},
        name = digit_class,
        hoverinfo = 'text',
        text = class_indices
    ))
iplot(go.Figure(data=scatter_data, layout=layout), config=config)
plot_points = np.array(plot_points)
For this, we first have to find the centers of the latent points.
from sklearn.cluster import KMeans
# Find the centers of the latent representations with k-means,
# using one cluster per RBF center.
kmeans_model = KMeans(n_clusters=num_centers, random_state=0).fit(encoded_mean)
centers = kmeans_model.cluster_centers_
# Overlay the cluster centers (red) on the class scatter plot
center_plot = go.Scatter(
    x=centers[:, 0],
    y=centers[:, 1],
    mode='markers',
    marker={'color': 'red'}
)
data = scatter_data + [center_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
# Group the latent codes by their nearest k-means center
clustering = {c_i: [] for c_i in range(num_centers)}
for z_i, c_i in zip(encoded_mean, kmeans_model.predict(encoded_mean)):
    clustering[c_i].append(z_i)
# Derive one RBF bandwidth per cluster from the scaled mean distance of the
# cluster's points to its center
bandwidths = []
for c_i, cluster in clustering.items():
    if not cluster:
        # Empty cluster: fall back to a zero bandwidth
        bandwidths.append(0)
        continue
    offsets = np.array(cluster) - centers[c_i]
    mean_radius = np.linalg.norm(offsets, axis=1).mean()
    bandwidths.append(0.5 / (a * mean_radius)**2)
bandwidths = np.array(bandwidths)
# Train the RBF variance network:
# recompile so that only the variance network's weights are optimized
vae.recompile_for_var_training()
# Initialize the RBF with the k-means centers and per-cluster bandwidths,
# keeping its already-initialized kernel weights
rbf_kernel = rbf.get_weights()[0]
rbf.set_weights([rbf_kernel, centers, bandwidths])
history = vae.model.fit(x_train,
                        epochs=100,
                        batch_size=32,
                        validation_split=0.1,
                        verbose=0)
# Plot the train/validation loss curves of the variance-network training
loss_trace = go.Scatter(y=history.history['loss'], name='Train Loss')
val_loss_trace = go.Scatter(y=history.history['val_loss'],
                            name='Validation Loss')
data = [loss_trace, val_loss_trace]
# Copy the shared layout and add axis titles for this plot only
plot_layout = deepcopy(layout)
plot_layout['xaxis'] = {'title': 'Epoch'}
plot_layout['yaxis'] = {'title': 'NELBO'}
plot_layout['showlegend'] = True
iplot(go.Figure(data=data, layout=plot_layout), config=config)
from src.util import wrap_model_in_float64
# Get the mean and std predictors from the trained decoder.
# The decoder outputs (sample, mean, variance); the std predictor takes the
# square root of the variance output.
_, mean_output, var_output = vae.decoder.output
sqrt_layer = Lambda(tf.sqrt)
dec_mean = Model(vae.decoder.input, mean_output)
dec_std = Model(vae.decoder.input, sqrt_layer(var_output))
# Wrap both models in float64 for the geodesic computations below
dec_mean = wrap_model_in_float64(dec_mean)
dec_std = wrap_model_in_float64(dec_std)
session = tf.keras.backend.get_session()
Next, we pick two latent points as the start and end for finding a geodesic.
# Pick the start and end points (indices into plot_points) for the geodesic
z_start_index = 1586
z_end_index = 2051
z_start, z_end = plot_points[[z_start_index, z_end_index]]
# Visualize the two endpoints on top of the class scatter plot
task_plot = go.Scatter(
    x = [z_start[0], z_end[0]],
    y = [z_start[1], z_end[1]],
    mode = 'markers',
    # Fixed: the hex color was missing its leading '#' ('d32f2f' is not a
    # valid color string; the same red is written '#d32f2f' elsewhere)
    marker = {'color': '#d32f2f'}
)
data = scatter_data + [task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
from src.plot import plot_magnification_factor
# Plot the magnification factor on a 100x100 grid covering the latent region
# of interest, on a log scale, with the class scatter and endpoints overlaid
heatmap_z1 = np.linspace(-5, 5, 100)
heatmap_z2 = np.linspace(-5, 5, 100)
heatmap = plot_magnification_factor(session,
                                    heatmap_z1,
                                    heatmap_z2,
                                    dec_mean,
                                    dec_std,
                                    additional_data=scatter_data + [task_plot],
                                    layout=layout,
                                    log_scale=True)
Before we start comparing the geodesic approximations, we need to define the metric. For each curve, we take equidistant steps in the latent space in order to compute the Riemannian length using numerical integration. We also plot the curve velocity.
from src.util import get_length_op, get_lengths_op, interpolate
# Placeholder for a discretized latent curve of shape (num_nodes, 2)
curve_ph = tf.placeholder(tf.float64, [None, 2])
# Total Riemannian length of the fed curve...
length_op, _ = get_length_op(curve_ph, dec_mean, dec_std)
# ...and its per-segment lengths (squeezed to a flat vector)
lengths_op = get_lengths_op(curve_ph, dec_mean, dec_std)
lengths_op = tf.squeeze(lengths_op)
def evaluate_curve(curve, num_nodes=200, with_velocity_plot=True,
                   verbose=True):
    """Measure a latent curve's Riemannian length on a dense discretization.

    The curve is re-interpolated to `num_nodes` equidistant nodes, the
    per-segment lengths are evaluated in the session, and their sum is
    returned. Optionally prints the length and plots the curve velocity.
    """
    dense_curve = interpolate(curve, num_nodes)
    segment_lengths = session.run(lengths_op,
                                  feed_dict={curve_ph: dense_curve})
    total_length = np.sum(segment_lengths)
    if verbose:
        print('Curve length: ', total_length)
    if with_velocity_plot:
        plot_velocity(segment_lengths)
    return total_length
def plot_velocity(lengths):
    """Plot the velocity profile implied by a curve's segment lengths.

    With nodes equidistant in t, the velocity over a segment is its length
    divided by the step size 1 / (num_nodes - 1).
    """
    num_nodes = len(lengths)
    velocity_trace = go.Scatter(
        x=np.linspace(0, 1, num_nodes),
        y=lengths * (num_nodes - 1)
    )
    figure = go.Figure(data=[velocity_trace], layout=go.Layout(
        width=700,
        height=100,
        margin=go.Margin(l=60, r=60, b=20, t=20),
        showlegend=False
    ))
    iplot(figure, config=config)
# Baseline: the straight (Euclidean) line between the endpoints, 50 nodes
t_nodes = np.linspace(0, 1, 50)
euclidean_curve = z_start + t_nodes[:, np.newaxis] * (z_end - z_start)
evaluate_curve(euclidean_curve)
%%time
from src.discrete import find_geodesic_discrete
# Approximate the geodesic with the discrete geodesic algorithm:
# gradient descent on a 50-node discretization of the curve
discrete_curve, discrete_iterations = find_geodesic_discrete(
    session,
    z_start, z_end,
    dec_mean,
    std_generator=dec_std,
    num_nodes=50,
    max_steps=400,
    learning_rate=0.01,
    log_every=50,
    save_every=30)
print('-' * 20)
# Re-measure the resulting curve on a dense discretization
evaluate_curve(discrete_curve)
Like in 14-graph-geodesics-moons.ipynb, the discrete geodesic algorithm's length estimate is strongly biased due to jumps over regions of large Riemannian metrics. Therefore, the actual curve length, 116, is higher than the length it converges to (43.4).
The ODE often does not converge. If it does, it takes orders of magnitude longer than the discrete geodesic algorithm without reaching a better solution. This is because the ODE only finds local minima just like the discrete geodesic algorithm. Neither of those methods searches globally for a solution.
%%time
from src.geodesic import find_geodesic
# Solve the geodesic ODE boundary value problem directly
ode_result, ode_iterations = find_geodesic(session, z_start, z_end,
                                           dec_mean, std_generator=dec_std,
                                           initial_nodes=20, max_nodes=1000,
                                           use_fun_jac=True)
print('-' * 20)
from src.plot import plot_latent_curve_iterations
# Show every 10th solver iteration on top of the heatmap and scatter plot
plot_latent_curve_iterations(ode_iterations[::10], [heatmap] + scatter_data,
                             layout, step_size=10)
# Evaluate the solution at the solver nodes; the first two rows of the
# solution are the curve coordinates
ode_curve = ode_result.sol(ode_result.x)[0:2].T
evaluate_curve(ode_curve)
We use the 500 points per digit from the latent plots above and add three times as many random gaussian noise points. This gives a total of 10,000 points in the latent space, which we will use for our graph in the latent space
# Augment the latent points with three noisy copies (unit Gaussian noise)
# to densify the point set for the graph
noisy_copies = []
for _ in range(3):
    noisy_copies.append(plot_points + np.random.randn(*plot_points.shape))
graph_points = np.concatenate([plot_points] + noisy_copies)
print(graph_points.shape)
To get the nearest neighbors of each point, we use the get_neighbors function from src.graph. It is explained and defined in 13-graph-geodesics.ipynb.
Given the get_neighbors function, compute the Riemannian distance between each point and each of its neighbors. We approximate the Riemannian distance with a single midpoint for integration: $\int_0^1 \left\| J_{\gamma_t} \dot{\gamma}_t \right\| \mathrm{d}t \approx \left\| J_{\gamma_{1/2}} \dot{\gamma}_{1/2} \right\|$
import networkx as nx
from tqdm import tqdm
from src.util import get_metric_op
from src.graph import get_neighbors
# Metric tensor of the decoder, evaluated at a single latent point
point_ph = tf.placeholder(tf.float64, [2])
metric_op = get_metric_op(point_ph, dec_mean, dec_std)
# Build the kNN graph; edge weights are Riemannian edge lengths
k = 4
graph = nx.Graph()
for node_index, node_pos in enumerate(graph_points):
    graph.add_node(node_index, pos=node_pos)
for node_index, node_pos in enumerate(tqdm(graph_points)):
    for neighbor_index in get_neighbors(node_index, graph_points, k):
        # Each undirected edge is weighted only once
        if graph.has_edge(neighbor_index, node_index):
            continue
        neighbor_pos = graph_points[neighbor_index]
        # Single-midpoint approximation of the Riemannian edge length:
        # sqrt(v^T M(midpoint) v) with v the edge vector
        midpoint = node_pos + 0.5 * (neighbor_pos - node_pos)
        velocity = neighbor_pos - node_pos
        metric = session.run(metric_op, feed_dict={point_ph: midpoint})
        edge_length = np.sqrt(velocity.T.dot(metric).dot(velocity))
        graph.add_edge(node_index, neighbor_index, weight=edge_length)
and the relative weight of the edges (Riemannian length divided by Euclidean length). Green means a low relative weight, red means a large relative weight.
from src.plot import plot_graph_with_edge_colors, plot_graph
# Bounding box spanned by the start and end point
x_range = [min(z_start[0], z_end[0]), max(z_start[0], z_end[0])]
y_range = [min(z_start[1], z_end[1]), max(z_start[1], z_end[1])]
# Collect the graph nodes that fall inside the bounding box
subnodes = []
for node in graph.nodes():
    pos = graph.node[node]['pos']
    if (x_range[0] <= pos[0] <= x_range[1] and
            y_range[0] <= pos[1] <= y_range[1]):
        subnodes.append(node)
subgraph = graph.subgraph(subnodes)
# NOTE(review): `subgraph` is built but never used — the full `graph` is
# plotted below. Possibly `subgraph` was meant to be passed instead; confirm.
graph_plot = plot_graph_with_edge_colors(graph, layout=layout,
                                         additional_data=[task_plot])
Now we compute the shortest path through the graph between the two points from above.
%%time
from networkx.algorithms.shortest_paths.generic import shortest_path
path = shortest_path(graph, z_start_index, z_end_index, weight='weight') #2
length = 0
for source, sink in zip(path[:-1], path[1:]):
length += graph[source][sink]['weight']
print('Path length:', length)
print('-' * 20)
from src.plot import plot_graph
# Construct a small graph containing just the shortest path and plot it
# in green on top of the heatmap and scatter plot
path_graph = nx.Graph()
for node in path:
    path_graph.add_node(node, pos=graph_points[node])
for u, v in zip(path[:-1], path[1:]):
    path_graph.add_edge(u, v, weight=graph[u][v]['weight'])
_ = plot_graph(path_graph, layout=layout, edge_color='#00DD00',
               node_color='#00DD00', additional_data=[heatmap] + scatter_data)
Since we only computed the Riemannian distance for each edge using a single midpoint, the graph length is not exactly correct. It is not as strongly biased as the discrete geodesic algorithm's length estimate, but we should measure it as well with the interpolate function for a fair comparison.
# Re-measure the graph path with dense interpolation for a fair comparison
graph_curve = graph_points[path]
evaluate_curve(graph_curve)
# Traces for the graph path (green) and the discrete geodesic (red)
graph_curve_plot = go.Scatter(
    x=graph_curve[:, 0],
    y=graph_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#3CA64D'}
)
discrete_curve_plot = go.Scatter(
    x=discrete_curve[:, 0],
    y=discrete_curve[:, 1],
    mode='lines',
    line={'width': 5, 'color': '#d32f2f'}
)
# Overlay both curves on the heatmap, scatter plot and endpoints
data = [heatmap] + scatter_data + [graph_curve_plot,
                                   discrete_curve_plot, task_plot]
iplot(go.Figure(data=data, layout=layout), config=config)
# Refine the graph solution with a few discrete-geodesic gradient steps,
# using the graph path as the curve initialization
graphref_curve, _ = find_geodesic_discrete(
    session,
    z_start, z_end,
    dec_mean,
    std_generator=dec_std,
    num_nodes=50,
    max_steps=50,
    learning_rate=0.01,
    log_every=10,
    save_every=30,
    curve_init=graph_curve)
evaluate_curve(graphref_curve)
plot_latent_curve_iterations([graphref_curve], [heatmap] + scatter_data, layout)
Measure the runtime of each approach on 100 random pairs of points. We don't use the ODE here, since it takes orders of magnitude longer than the discrete geodesic algorithm without giving better geodesic approximations.
# Sample 100 random point pairs (indices into plot_points; drawn with
# replacement, so a start may coincide with an end)
z_starts = np.random.choice(len(plot_points), 100)
z_ends = np.random.choice(len(plot_points), 100)
def test_euclidean(z_starts, z_ends):
    """Straight-line baseline: measure the Riemannian length of the
    Euclidean line between each pair of latent points.

    Returns (curves, lengths) for all pairs.
    """
    curve_ph = tf.placeholder(tf.float64, [None, 2])
    length_op = get_length_op(curve_ph, dec_mean, dec_std)
    # Hoisted out of the loop: the node parameterization is pair-independent
    t_nodes = np.linspace(0, 1, 20)
    curves = []
    lengths = []
    for z_start, z_end in zip(plot_points[z_starts], plot_points[z_ends]):
        curve = z_start + np.outer(t_nodes, z_end - z_start)
        length, _ = session.run(length_op, feed_dict={curve_ph: curve})
        curves.append(curve)
        lengths.append(length)
    return curves, lengths
from src.discrete import find_geodesics_discrete
def test_discrete(z_starts, z_ends):
    """Run the discrete geodesic algorithm on all point pairs at once."""
    starts = plot_points[z_starts]
    ends = plot_points[z_ends]
    return find_geodesics_discrete(session,
                                   starts, ends,
                                   dec_mean,
                                   std_generator=dec_std,
                                   num_nodes=50,
                                   max_steps=400,
                                   learning_rate=0.01)
def test_graph(z_starts, z_ends):
    """Shortest graph path and its edge-weight length for each point pair."""
    curves = []
    lengths = []
    for z_start, z_end in zip(z_starts, z_ends):
        path = shortest_path(graph, z_start, z_end, weight='weight')
        curves.append(graph_points[path])
        lengths.append(sum(graph[u][v]['weight']
                           for u, v in zip(path[:-1], path[1:])))
    return curves, lengths
def test_graph_refinement(z_starts, z_ends):
    """Refine the shortest graph paths with the discrete geodesic algorithm.

    Each pair's shortest graph path is used as the curve initialization.
    """
    curve_inits = []
    for z_start, z_end in zip(z_starts, z_ends):
        path = shortest_path(graph, z_start, z_end, weight='weight')
        curve_inits.append(graph_points[path])
    # Fixed: the endpoints must be indexed from plot_points — the
    # z_starts/z_ends indices refer to plot_points (see how they are drawn
    # and used in the other test_* functions), not to encoded_mean, whose
    # ordering differs. Otherwise the endpoints disagree with curve_inits.
    return find_geodesics_discrete(
        session,
        plot_points[z_starts], plot_points[z_ends],
        dec_mean,
        std_generator=dec_std,
        num_nodes=50,
        max_steps=20,
        learning_rate=0.01,
        curve_inits=curve_inits)
%%time
# Time the Euclidean baseline over the 100 pairs
eucl_curves, eucl_est_lengths = test_euclidean(z_starts, z_ends)
print('-' * 20)
%%time
# Time the discrete geodesic algorithm
discrete_curves, discrete_est_lengths = test_discrete(z_starts, z_ends)
print('-' * 20)
%%time
# Time the graph shortest-path approach
graph_curves, graph_est_lengths = test_graph(z_starts, z_ends)
print('-' * 20)
%%time
# Time the graph + refinement approach
graphref_curves, graphref_est_lengths = test_graph_refinement(z_starts, z_ends)
print('-' * 20)
We see that the discrete geodesic algorithm gives strongly biased length estimates.
def evaluate_curves(curves, estimated_lengths, num_nodes=200):
    """Re-measure each curve on a dense discretization and report how far
    each method's own length estimate deviates from that measurement.

    Returns the densely measured lengths.
    """
    lengths = []
    estimation_errors = []
    for curve, estimate in zip(curves, estimated_lengths):
        dense_curve = interpolate(curve, num_nodes)
        measured = session.run(length_op, feed_dict={curve_ph: dense_curve})
        lengths.append(measured)
        estimation_errors.append(estimate - measured)
    print('Estimation error mean: ', np.mean(estimation_errors))
    print('Estimation error std: ', np.std(estimation_errors))
    return lengths
# Densely re-measure the curves produced by every method
eucl_lengths = evaluate_curves(eucl_curves, eucl_est_lengths)
discrete_lengths = evaluate_curves(discrete_curves, discrete_est_lengths)
graph_lengths = evaluate_curves(graph_curves, graph_est_lengths)
graphref_lengths = evaluate_curves(graphref_curves, graphref_est_lengths)
def _length_trace(lengths, color, label):
    """Scatter of a method's measured lengths against the graph solution's."""
    return go.Scatter(
        x=graph_lengths,
        y=np.array(lengths),
        mode='markers',
        marker={'size': 8, 'symbol': 'x', 'color': color},
        name=label
    )

eucl_trace = _length_trace(eucl_lengths, 'orange', 'Euclidean')
graph_trace = _length_trace(graph_lengths, '#3CA8FF', 'Graph ')
discrete_trace = _length_trace(discrete_lengths, '#d32f2f', 'Discrete')
data = [eucl_trace, graph_trace, discrete_trace]
# Larger layout with axis titles for the method comparison plot
_layout = go.Layout(
    width=800,
    height=600,
    margin=go.Margin(l=60, r=60, b=40, t=20),
    xaxis={
        'title': 'Length of graph solution',
        'titlefont': {'size': 18}
    },
    yaxis={
        'title': 'Stochastic Riemannian length',
        'titlefont': {'size': 18}
    },
    legend={
        'font': {'size': 18}
    }
)
iplot(go.Figure(data=data, layout=_layout), config=config)
# Add the graph-refinement results to the comparison plot
graphref_trace = go.Scatter(
    x=graph_lengths,
    y=np.array(graphref_lengths),
    mode='markers',
    marker={'size': 8, 'symbol': 'x', 'color': '#2D4366'},
    name='Graph Refinement'
)
data = [eucl_trace, graph_trace, discrete_trace, graphref_trace]
iplot(go.Figure(data=data, layout=_layout), config=config)
Stochastic Riemannian length of shortest curves found:
Time per 100 geodesic approximations: